{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MQLsAJC01D5m",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Train a model with Episodic Training\n",
    "Episodic training has attracted a lot of interest in the early years of Few-Shot Learning research. Some papers still use it, and refer to it as \"meta-learning\".\n",
    "\n",
    "Recent works distinguish the Few-Shot Classifier from the training framework, so as from v1.0 of EasyFSL, methods to episodically train a classifier were taken out of the logic of the FewShotClassifier class. Instead, we provide in this notebook an example of how to perform episodic training on a few-shot classifier.\n",
    "\n",
    "Use it, copy it, change it, get crazy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "o5aFVwdT1D5t",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Getting started\n",
    "First we're going to do some imports (this is not the interesting part)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "u-qkFjp9h4Mh",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "try:\n",
    "    import google.colab\n",
    "    colab = True\n",
    "except:\n",
    "    colab = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "QTwWMJerh5yF",
    "outputId": "26dc469a-3b9f-46eb-950e-ce1d322ad126",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "if colab is True:\n",
    "    # Running in Google Colab\n",
    "    # Clone the repo\n",
    "    !git clone https://github.com/sicara/easy-few-shot-learning\n",
    "    %cd easy-few-shot-learning\n",
    "    !pip install .\n",
    "else:\n",
    "    # Run locally\n",
    "    # Ensure working directory is the project's root\n",
    "    # Make sure easyfsl is installed!\n",
    "    %cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "A1LKM7vX1D5v",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "from pathlib import Path\n",
    "import random\n",
    "from statistics import mean\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jkgsfhkD1D5z",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Then we're gonna do the most important thing in Machine Learning research: ensuring reproducibility by setting the random seed. We're going to set the seed for all random packages that we could possibly use, plus some other stuff to make CUDA deterministic (see [here](https://pytorch.org/docs/stable/notes/randomness.html)).\n",
    "\n",
    "I strongly encourage that you do this in **all your scripts**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ocLk4O-G1D50",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "random_seed = 0\n",
    "np.random.seed(random_seed)\n",
    "torch.manual_seed(random_seed)\n",
    "random.seed(random_seed)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OtQokWas1D51",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Then we're gonna set the shape of our problem.\n",
    "\n",
    "Also we define our set-up, like the device (change it if you don't have CUDA) or the number of workers for data loading."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_0rB2h7C1D52",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "n_way = 5\n",
    "n_shot = 5\n",
    "n_query = 10\n",
    "\n",
    "DEVICE = \"cuda\"\n",
    "n_workers = 12"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "g-PT7nk81D53",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Training\n",
    "\n",
    "First we define our data loaders for training and validation. You can see that I chose tu use CUB in this notebook, because it's a small dataset, so we can have good results quite quickly. We use `CUB` and `TaskSampler` which are built-in objects from EasyFSL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NfPcFd3wiAp5",
    "outputId": "ec11a6fe-84bd-4138-977b-b2e9fd1c5692",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Download the CUB dataset\n",
    "!make download-cub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1T7446uT1D54",
    "outputId": "0de2ffe7-fb9a-405a-d73f-8d802f6e7313",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from easyfsl.datasets import CUB\n",
    "from easyfsl.samplers import TaskSampler\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "n_tasks_per_epoch = 500\n",
    "n_validation_tasks = 100\n",
    "\n",
    "# Instantiate the datasets\n",
    "train_set = CUB(split=\"train\", training=True)\n",
    "val_set = CUB(split=\"val\", training=False)\n",
    "\n",
    "# Those are special batch samplers that sample few-shot classification tasks with a pre-defined shape\n",
    "train_sampler = TaskSampler(\n",
    "    train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch\n",
    ")\n",
    "val_sampler = TaskSampler(\n",
    "    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks\n",
    ")\n",
    "\n",
    "# Finally, the DataLoader. We customize the collate_fn so that batches are delivered\n",
    "# in the shape: (support_images, support_labels, query_images, query_labels, class_ids)\n",
    "train_loader = DataLoader(\n",
    "    train_set,\n",
    "    batch_sampler=train_sampler,\n",
    "    num_workers=n_workers,\n",
    "    pin_memory=True,\n",
    "    collate_fn=train_sampler.episodic_collate_fn,\n",
    ")\n",
    "val_loader = DataLoader(\n",
    "    val_set,\n",
    "    batch_sampler=val_sampler,\n",
    "    num_workers=n_workers,\n",
    "    pin_memory=True,\n",
    "    collate_fn=val_sampler.episodic_collate_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CfX-JEEd1D55",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "And then we define the network. Here I chose Prototypical Networks and the built-in ResNet18 from PyTorch because it's easy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7gW9FSUf1D55",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from easyfsl.methods import PrototypicalNetworks, FewShotClassifier\n",
    "from easyfsl.modules import resnet12\n",
    "\n",
    "\n",
    "convolutional_network = resnet12()\n",
    "few_shot_classifier = PrototypicalNetworks(convolutional_network).to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KaDd3CVm1D56",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Now let's define our training helpers ! I chose to use Stochastic Gradient Descent on 200 epochs with a scheduler that divides the learning rate by 10 after 120 and 160 epochs. The strategy is derived from [this repo](https://github.com/fiveai/on-episodes-fsl).\n",
    "\n",
    "We're also gonna use a TensorBoard because it's always good to see what your training curves look like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "d2EZR3yA1D56",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from torch.optim import SGD, Optimizer\n",
    "from torch.optim.lr_scheduler import MultiStepLR\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "LOSS_FUNCTION = nn.CrossEntropyLoss()\n",
    "\n",
    "n_epochs = 200\n",
    "scheduler_milestones = [120, 160]\n",
    "scheduler_gamma = 0.1\n",
    "learning_rate = 1e-2\n",
    "tb_logs_dir = Path(\".\")\n",
    "\n",
    "train_optimizer = SGD(\n",
    "    few_shot_classifier.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4\n",
    ")\n",
    "train_scheduler = MultiStepLR(\n",
    "    train_optimizer,\n",
    "    milestones=scheduler_milestones,\n",
    "    gamma=scheduler_gamma,\n",
    ")\n",
    "\n",
    "tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xTgldYnI1D57",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "And now let's get to it! Here we define the function that performs a training epoch.\n",
    "\n",
    "We use tqdm to monitor the training in real time in our logs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fyLOQ2A21D58",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def training_epoch(\n",
    "    model: FewShotClassifier, data_loader: DataLoader, optimizer: Optimizer\n",
    "):\n",
    "    all_loss = []\n",
    "    model.train()\n",
    "    with tqdm(\n",
    "        enumerate(data_loader), total=len(data_loader), desc=\"Training\"\n",
    "    ) as tqdm_train:\n",
    "        for episode_index, (\n",
    "            support_images,\n",
    "            support_labels,\n",
    "            query_images,\n",
    "            query_labels,\n",
    "            _,\n",
    "        ) in tqdm_train:\n",
    "            optimizer.zero_grad()\n",
    "            model.process_support_set(\n",
    "                support_images.to(DEVICE), support_labels.to(DEVICE)\n",
    "            )\n",
    "            classification_scores = model(query_images.to(DEVICE))\n",
    "\n",
    "            loss = LOSS_FUNCTION(classification_scores, query_labels.to(DEVICE))\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            all_loss.append(loss.item())\n",
    "\n",
    "            tqdm_train.set_postfix(loss=mean(all_loss))\n",
    "\n",
    "    return mean(all_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gMyYb6pB1D58",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "And we have everything we need! To perform validations we'll just use the built-in `evaluate` function from `easyfsl.methods.utils`.\n",
    "\n",
    "This is now the time to **start training**.\n",
    "\n",
    "I added something to log the state of the model that gave the best performance on the validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from easyfsl.utils import evaluate\n",
    "\n",
    "\n",
    "best_state = few_shot_classifier.state_dict()\n",
    "best_validation_accuracy = 0.0\n",
    "for epoch in range(n_epochs):\n",
    "    print(f\"Epoch {epoch}\")\n",
    "    average_loss = training_epoch(few_shot_classifier, train_loader, train_optimizer)\n",
    "    validation_accuracy = evaluate(\n",
    "        few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix=\"Validation\"\n",
    "    )\n",
    "\n",
    "    if validation_accuracy > best_validation_accuracy:\n",
    "        best_validation_accuracy = validation_accuracy\n",
    "        best_state = copy.deepcopy(few_shot_classifier.state_dict())\n",
    "        # state_dict() returns a reference to the still evolving model's state so we deepcopy\n",
    "        # https://pytorch.org/tutorials/beginner/saving_loading_models\n",
    "        print(\"Ding ding ding! We found a new best model!\")\n",
    "\n",
    "    tb_writer.add_scalar(\"Train/loss\", average_loss, epoch)\n",
    "    tb_writer.add_scalar(\"Val/acc\", validation_accuracy, epoch)\n",
    "\n",
    "    # Warn the scheduler that we did an epoch\n",
    "    # so it knows when to decrease the learning rate\n",
    "    train_scheduler.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Yay we successfully performed Episodic Training! Now if you want to you can retrieve the best model's state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "few_shot_classifier.load_state_dict(best_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Evaluation\n",
    "\n",
    "Now that our model is trained, we want to test it.\n",
    "\n",
    "First step: we fetch the test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "n_test_tasks = 1000\n",
    "\n",
    "test_set = CUB(split=\"test\", training=False)\n",
    "test_sampler = TaskSampler(\n",
    "    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks\n",
    ")\n",
    "test_loader = DataLoader(\n",
    "    test_set,\n",
    "    batch_sampler=test_sampler,\n",
    "    num_workers=n_workers,\n",
    "    pin_memory=True,\n",
    "    collate_fn=test_sampler.episodic_collate_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Second step: we run the few-shot classifier on the test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)\n",
    "print(f\"Average accuracy : {(100 * accuracy):.2f} %\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Congrats! You performed Episodic Training using EasyFSL. If you want to compare with a model trained using classical training, look at [this other example notebook](classical_training.ipynb).\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "episodic_training.ipynb",
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}